CS236781: Deep Learning on Computational Accelerators¶

Homework Assignment 3¶

Faculty of Computer Science, Technion.

Submitted by:

# Name Id email
Student 1 [Jamal Tannous] [208912337] [jamaltannous@campus.technion.ac.il]
Student 2 [Snir Hordan] [205689581] [snirhordan@campus.technion.ac.il]

Introduction¶

In this assignment we'll create a from-scratch implementation of two fundemental deep learning concepts: the backpropagation algorithm and stochastic gradient descent-based optimizers. Following that, we'll focus on sequences, and learn to generate text with a deep multilayer RNN network based on GRU cells.

General Guidelines¶

  • Please read the getting started page on the course website. It explains how to setup, run and submit the assignment.
  • Please read the course servers usage guide. It explains how to use and run your code on the course servers to benefit from training with GPUs.
  • The text and code cells in these notebooks are intended to guide you through the assignment and help you verify your solutions. The notebooks do not need to be edited at all (unless you wish to play around). The only exception is to fill your name(s) in the above cell before submission. Please do not remove sections or change the order of any cells.
  • All your code (and even answers to questions) should be written in the files within the python package corresponding the assignment number (hw1, hw2, etc). You can of course use any editor or IDE to work on these files.

Contents¶

  • Part 1: Sequence Models
  • Part 2: Variational Autoencoder
  • Part 3: Generative Adversarial Networks
$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bb}[1]{\boldsymbol{#1}} $$

Part 1: Sequence Models¶

In this part we will learn about working with text sequences using recurrent neural networks. We'll go from a raw text file all the way to a fully trained GRU-RNN model and generate works of art!

In [53]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [54]:
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Text generation with a char-level RNN¶

Obtaining the corpus¶

Let's begin by downloading a corpus containing all the works of William Shakespeare. Since he was very prolific, this corpus is fairly large and will provide us with enough data for obtaining impressive results.

In [55]:
CORPUS_URL = 'https://github.com/cedricdeboom/character-level-rnn-datasets/raw/master/datasets/shakespeare.txt'
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')

def download_corpus(out_path=DATA_DIR, url=CORPUS_URL, force=False):
    pathlib.Path(out_path).mkdir(exist_ok=True)
    out_filename = os.path.join(out_path, os.path.basename(url))
    
    if os.path.isfile(out_filename) and not force:
        print(f'Corpus file {out_filename} exists, skipping download.')
    else:
        print(f'Downloading {url}...')
        with urllib.request.urlopen(url) as response, open(out_filename, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        print(f'Saved to {out_filename}.')
    return out_filename
    
corpus_path = download_corpus()
Corpus file /home/snirhordan/.pytorch-datasets/shakespeare.txt exists, skipping download.

Load the text into memory and print a snippet:

In [56]:
with open(corpus_path, 'r', encoding='utf-8') as f:
    corpus = f.read()

print(f'Corpus length: {len(corpus)} chars')
print(corpus[7:1234])
Corpus length: 6347703 chars
ALLS WELL THAT ENDS WELL

by William Shakespeare

Dramatis Personae

  KING OF FRANCE
  THE DUKE OF FLORENCE
  BERTRAM, Count of Rousillon
  LAFEU, an old lord
  PAROLLES, a follower of Bertram
  TWO FRENCH LORDS, serving with Bertram

  STEWARD, Servant to the Countess of Rousillon
  LAVACHE, a clown and Servant to the Countess of Rousillon
  A PAGE, Servant to the Countess of Rousillon

  COUNTESS OF ROUSILLON, mother to Bertram
  HELENA, a gentlewoman protected by the Countess
  A WIDOW OF FLORENCE.
  DIANA, daughter to the Widow

  VIOLENTA, neighbour and friend to the Widow
  MARIANA, neighbour and friend to the Widow

  Lords, Officers, Soldiers, etc., French and Florentine  

SCENE:
Rousillon; Paris; Florence; Marseilles

ACT I. SCENE 1.
Rousillon. The COUNT'S palace

Enter BERTRAM, the COUNTESS OF ROUSILLON, HELENA, and LAFEU, all in black

  COUNTESS. In delivering my son from me, I bury a second husband.
  BERTRAM. And I in going, madam, weep o'er my father's death anew;
    but I must attend his Majesty's command, to whom I am now in
    ward, evermore in subjection.
  LAFEU. You shall find of the King a husband, madam; you, sir, a
    father. He that so generally is at all times good must of
    

Data Preprocessing¶

The first thing we'll need is to map from each unique character in the corpus to an index that will represent it in our learning process.

TODO: Implement the char_maps() function in the hw3/charnn.py module.

In [57]:
import hw3.charnn as charnn

char_to_idx, idx_to_char = charnn.char_maps(corpus)
print(char_to_idx)

test.assertEqual(len(char_to_idx), len(idx_to_char))
test.assertSequenceEqual(list(char_to_idx.keys()), list(idx_to_char.values()))
test.assertSequenceEqual(list(char_to_idx.values()), list(idx_to_char.keys()))
{'f': 0, '6': 1, 'I': 2, 'W': 3, 'X': 4, 'B': 5, ']': 6, 'C': 7, 'D': 8, '"': 9, '&': 10, '[': 11, ':': 12, 'n': 13, 'i': 14, 't': 15, '\n': 16, 'c': 17, 'F': 18, 'm': 19, 'g': 20, 'e': 21, 'A': 22, 'h': 23, '}': 24, 'v': 25, 'x': 26, '(': 27, 'w': 28, ';': 29, 'J': 30, 'K': 31, '!': 32, '_': 33, '4': 34, 'Y': 35, '1': 36, 'T': 37, 'R': 38, 'u': 39, 'o': 40, ' ': 41, 'U': 42, 'M': 43, ')': 44, 'l': 45, '8': 46, ',': 47, 'd': 48, 'L': 49, '$': 50, 'S': 51, 'b': 52, 'r': 53, 'y': 54, 'q': 55, '3': 56, 'Z': 57, '.': 58, '\ufeff': 59, '7': 60, '5': 61, '-': 62, "'": 63, '<': 64, 'z': 65, 'H': 66, 'V': 67, 'Q': 68, 'G': 69, 'j': 70, 'O': 71, '?': 72, 'k': 73, 'E': 74, '0': 75, '9': 76, 'N': 77, 'a': 78, 'P': 79, 'p': 80, '2': 81, 's': 82}

Seems we have some strange characters in the corpus that are very rare and are probably due to mistakes. To reduce the length of each tensor we'll need to later represent our chars, it's best to remove them.

TODO: Implement the remove_chars() function in the hw3/charnn.py module.

In [58]:
corpus, n_removed = charnn.remove_chars(corpus, ['}','$','_','<','\ufeff'])
print(f'Removed {n_removed} chars')

# After removing the chars, re-create the mappings
char_to_idx, idx_to_char = charnn.char_maps(corpus)
Removed 34 chars

The next thing we need is an embedding of the chracters. An embedding is a representation of each token from the sequence as a tensor. For a char-level RNN, our tokens will be chars and we can thus use the simplest possible embedding: encode each char as a one-hot tensor. In other words, each char will be represented as a tensor whos length is the total number of unique chars (V) which contains all zeros except at the index corresponding to that specific char.

TODO: Implement the functions chars_to_onehot() and onehot_to_chars() in the hw3/charnn.py module.

In [59]:
# Wrap the actual embedding functions for calling convenience
def embed(text):
    return charnn.chars_to_onehot(text, char_to_idx)

def unembed(embedding):
    return charnn.onehot_to_chars(embedding, idx_to_char)

text_snippet = corpus[3104:3148]
print(text_snippet)
print(embed(text_snippet[0:3]))

test.assertEqual(text_snippet, unembed(embed(text_snippet)))
test.assertEqual(embed(text_snippet).dtype, torch.int8)
brine a maiden can season her praise in.
   
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0]], dtype=torch.int8)

Dataset Creation¶

We wish to train our model to generate text by constantly predicting what the next char should be based on the past. To that end we'll need to train our recurrent network in a way similar to a classification task. At each timestep, we input a char and set the expected output (label) to be the next char in the original sequence.

We will split our corpus into shorter sequences of length S chars (see question below). Each sample we provide our model with will therefore be a tensor of shape (S,V) where V is the embedding dimension. Our model will operate sequentially on each char in the sequence. For each sample, we'll also need a label. This is simply another sequence, shifted by one char so that the label of each char is the next char in the corpus.

TODO: Implement the chars_to_labelled_samples() function in the hw3/charnn.py module.

In [60]:
# Create dataset of sequences
seq_len = 64
vocab_len = len(char_to_idx)

# Create labelled samples
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
print(f'samples shape: {samples.shape}')
print(f'labels shape: {labels.shape}')

# Test shapes
num_samples = (len(corpus) - 1) // seq_len
test.assertEqual(samples.shape, (num_samples, seq_len, vocab_len))
test.assertEqual(labels.shape, (num_samples, seq_len))

# Test content
for _ in range(1000):
    # random sample
    i = np.random.randint(num_samples, size=(1,))[0]
    # Compare to corpus
    test.assertEqual(unembed(samples[i]), corpus[i*seq_len:(i+1)*seq_len], msg=f"content mismatch in sample {i}")
    # Compare to labels
    sample_text = unembed(samples[i])
    label_text = str.join('', [idx_to_char[j.item()] for j in labels[i]])
    test.assertEqual(sample_text[1:], label_text[0:-1], msg=f"label mismatch in sample {i}")
samples shape: torch.Size([99182, 64, 78])
labels shape: torch.Size([99182, 64])

Let's print a few consecutive samples. You should see that the text continues between them.

In [61]:
import re
import random

i = random.randrange(num_samples-5)
for i in range(i, i+5):
    test.assertEqual(len(samples[i]), seq_len)
    s = re.sub(r'\s+', ' ', unembed(samples[i])).strip()
    print(f'sample [{i}]:\n\t{s}')
sample [7649]:
	o be ballast at her nose. ANTIPHOLUS OF SYRACUSE. Where stood
sample [7650]:
	Belgia, the Netherlands? DROMIO OF SYRACUSE. O, Sir, I did not l
sample [7651]:
	ook so low. To conclude: this drudge or diviner laid claim to
sample [7652]:
	me; call'd me Dromio; swore I was assur'd to her; told me what
sample [7653]:
	privy marks I had about me, as, the mark of my shoulder, the

As usual, instead of feeding one sample at a time into our model's forward we'll work with batches of samples. This means that at every timestep, our model will operate on a batch of chars that are from different sequences. Effectively this will allow us to parallelize training our model by dong matrix-matrix multiplications instead of matrix-vector during the forward pass.

An important nuance is that we need the batches to be contiguous, i.e. sample $k$ in batch $j$ should continue sample $k$ from batch $j-1$. The following figure illustrates this:

If we naïvely take consecutive samples into batches, e.g. [0,1,...,B-1], [B,B+1,...,2B-1] and so on, we won't have contiguous sequences at the same index between adjacent batches.

To accomplish this we need to tell our DataLoader which samples to combine together into one batch. We do this by implementing a custom PyTorch Sampler, and providing it to our DataLoader.

TODO: Implement the SequenceBatchSampler class in the hw3/charnn.py module.

In [62]:
from hw3.charnn import SequenceBatchSampler

sampler = SequenceBatchSampler(dataset=range(32), batch_size=10)
sampler_idx = list(sampler)
print('sampler_idx =\n', sampler_idx)

# Test the Sampler
test.assertEqual(len(sampler_idx), 30)
batch_idx = np.array(sampler_idx).reshape(-1, 10)
for k in range(10):
    test.assertEqual(np.diff(batch_idx[:, k], n=2).item(), 0)
sampler_idx =
 [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29]

Even though we're working with sequences, we can still use the standard PyTorch Dataset/DataLoader combo. For the dataset we can use a built-in class, TensorDataset to return tuples of (sample, label) from the samples and labels tensors we created above. The DataLoader will be provided with our custom Sampler so that it generates appropriate batches.

In [63]:
import torch.utils.data

# Create DataLoader returning batches of samples.
batch_size = 32

ds_corpus = torch.utils.data.TensorDataset(samples, labels)
sampler_corpus = SequenceBatchSampler(ds_corpus, batch_size)
dl_corpus = torch.utils.data.DataLoader(ds_corpus, batch_size=batch_size, sampler=sampler_corpus, shuffle=False)

Let's see what that gives us:

In [64]:
print(f'num batches: {len(dl_corpus)}')

x0, y0 = next(iter(dl_corpus))
print(f'shape of a batch of samples: {x0.shape}')
print(f'shape of a batch of labels: {y0.shape}')
num batches: 3100
shape of a batch of samples: torch.Size([32, 64, 78])
shape of a batch of labels: torch.Size([32, 64])

Now lets look at the same sample index from multiple batches taken from our corpus.

In [65]:
# Check that sentences in in same index of different batches complete each other.
k = random.randrange(batch_size)
for j, (X, y) in enumerate(dl_corpus,):
    print(f'=== batch {j}, sample {k} ({X[k].shape}): ===')
    s = re.sub(r'\s+', ' ', unembed(X[k])).strip()
    print(f'\t{s}')
    if j==4: break
=== batch 0, sample 3 (torch.Size([64, 78])): ===
	defective for requital Than we to stretch it out. Masters o
=== batch 1, sample 3 (torch.Size([64, 78])): ===
	' th' people, We do request your kindest ears; and, after,
=== batch 2, sample 3 (torch.Size([64, 78])): ===
	Your loving motion toward the common body, To yield wha
=== batch 3, sample 3 (torch.Size([64, 78])): ===
	t passes here. SICINIUS. We are convented Upon a pleasing
=== batch 4, sample 3 (torch.Size([64, 78])): ===
	treaty, and have hearts Inclinable to honour and advance

Model Implementation¶

Finally, our data set is ready so we can focus on our model.

We'll implement here is a multilayer gated recurrent unit (GRU) model, with dropout. This model is a type of RNN which performs similar to the well-known LSTM model, but it's somewhat easier to train because it has less parameters. We'll modify the regular GRU slightly by applying dropout to the hidden states passed between layers of the model.

The model accepts an input $\mat{X}\in\set{R}^{S\times V}$ containing a sequence of embedded chars. It returns an output $\mat{Y}\in\set{R}^{S\times V}$ of predictions for the next char and the final hidden state $\mat{H}\in\set{R}^{L\times H}$. Here $S$ is the sequence length, $V$ is the vocabulary size (number of unique chars), $L$ is the number of layers in the model and $H$ is the hidden dimension.

Mathematically, the model's forward function at layer $k\in[1,L]$ and timestep $t\in[1,S]$ can be described as

$$ \begin{align} \vec{z_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xz}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hz}}}^{[k]} + \vec{b}_{\mathrm{z}}^{[k]}\right) \\ \vec{r_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xr}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hr}}}^{[k]} + \vec{b}_{\mathrm{r}}^{[k]}\right) \\ \vec{g_t}^{[k]} &= \tanh\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xg}}}^{[k]} + (\vec{r_t}^{[k]}\odot\vec{h}_{t-1}^{[k]}) {\mattr{W}_{\mathrm{hg}}}^{[k]} + \vec{b}_{\mathrm{g}}^{[k]}\right) \\ \vec{h_t}^{[k]} &= \vec{z}^{[k]}_t \odot \vec{h}^{[k]}_{t-1} + \left(1-\vec{z}^{[k]}_t\right)\odot \vec{g_t}^{[k]} \end{align} $$

The input to each layer is, $$ \mat{X}^{[k]} = \begin{bmatrix} {\vec{x}_1}^{[k]} \ \vdots \ {\vec{x}_S}^{[k]}

\end{bmatrix} ¶

\begin{cases} \mat{X} & \mathrm{if} ~k = 1~ \\ \mathrm{dropout}_p \left( \begin{bmatrix} {\vec{h}_1}^{[k-1]} \\ \vdots \\ {\vec{h}_S}^{[k-1]} \end{bmatrix} \right) & \mathrm{if} ~1 < k \leq L+1~ \end{cases}

. $$

The output of the entire model is then, $$ \mat{Y} = \mat{X}^{[L+1]} {\mattr{W}_{\mathrm{hy}}} + \mat{B}_{\mathrm{y}} $$

and the final hidden state is $$ \mat{H} = \begin{bmatrix} {\vec{h}_S}^{[1]} \\ \vdots \\ {\vec{h}_S}^{[L]} \end{bmatrix}. $$

Notes:

  • $t\in[1,S]$ is the timestep, i.e. the current position within the sequence of each sample.
  • $\vec{x}_t^{[k]}$ is the input of layer $k$ at timestep $t$, respectively.
  • The outputs of the last layer $\vec{y}_t^{[L]}$, are the predicted next characters for every input char. These are similar to class scores in classification tasks.
  • The hidden states at the last timestep, $\vec{h}_S^{[k]}$, are the final hidden state returned from the model.
  • $\sigma(\cdot)$ is the sigmoid function, i.e. $\sigma(\vec{z}) = 1/(1+e^{-\vec{z}})$ which returns values in $(0,1)$.
  • $\tanh(\cdot)$ is the hyperbolic tangent, i.e. $\tanh(\vec{z}) = (e^{2\vec{z}}-1)/(e^{2\vec{z}}+1)$ which returns values in $(-1,1)$.
  • $\vec{h_t}^{[k]}$ is the hidden state of layer $k$ at time $t$. This can be thought of as the memory of that layer.
  • $\vec{g_t}^{[k]}$ is the candidate hidden state for time $t+1$.
  • $\vec{z_t}^{[k]}$ is known as the update gate. It combines the previous state with the input to determine how much the current state will be combined with the new candidate state. For example, if $\vec{z_t}^{[k]}=\vec{1}$ then the current input has no effect on the output.
  • $\vec{r_t}^{[k]}$ is known as the reset gate. It combines the previous state with the input to determine how much of the previous state will affect the current state candidate. For example if $\vec{r_t}^{[k]}=\vec{0}$ the previous state has no effect on the current candidate state.

Here's a graphical representation of the GRU's forward pass at each timestep. The $\vec{\tilde{h}}$ in the image is our $\vec{g}$ (candidate next state).

You can see how the reset and update gates allow the model to completely ignore it's previous state, completely ignore it's input, or any mixture of those states (since the gates are actually continuous and between $(0,1)$).

Here's a graphical representation of the entire model. You can ignore the $c_t^{[k]}$ (cell state) variables (which are relevant for LSTM models). Our model has only the hidden state, $h_t^{[k]}$. Also notice that we added dropout between layers (i.e., on the up arrows).

The purple tensors are inputs (a sequence and initial hidden state per layer), and the green tensors are outputs (another sequence and final hidden state per layer). Each blue block implements the above forward equations. Blocks that are on the same vertical level are at the same layer, and therefore share parameters.

TODO: Implement the MultilayerGRU class in the hw3/charnn.py module.

Notes:

  • You'll need to handle input batches now. The math is identical to the above, but all the tensors will have an extra batch dimension as their first dimension.
  • Use the diagram above to help guide your implementation. It will help you visualize what shapes to returns where, etc.
In [66]:
in_dim = vocab_len
h_dim = 256
n_layers = 3
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers)
model = model.to(device)
print(model)

# Test forward pass
y, h = model(x0.to(dtype=torch.float, device=device))
print(f'y.shape={y.shape}')
print(f'h.shape={h.shape}')

test.assertEqual(y.shape, (batch_size, seq_len, vocab_len))
test.assertEqual(h.shape, (batch_size, n_layers, h_dim))
test.assertEqual(len(list(model.parameters())), 9 * n_layers + 2) 
MultilayerGRU(
  (Layer_0_xz): Linear(in_features=78, out_features=256, bias=True)
  (Layer_0_xr): Linear(in_features=78, out_features=256, bias=True)
  (Layer_0_xg): Linear(in_features=78, out_features=256, bias=True)
  (Layer_0_hz): Linear(in_features=256, out_features=256, bias=False)
  (Layer_0_hr): Linear(in_features=256, out_features=256, bias=False)
  (Layer_0_hg): Linear(in_features=256, out_features=256, bias=False)
  (Layer_0_dropout): Dropout(p=0, inplace=False)
  (Layer_1_xz): Linear(in_features=256, out_features=256, bias=True)
  (Layer_1_xr): Linear(in_features=256, out_features=256, bias=True)
  (Layer_1_xg): Linear(in_features=256, out_features=256, bias=True)
  (Layer_1_hz): Linear(in_features=256, out_features=256, bias=False)
  (Layer_1_hr): Linear(in_features=256, out_features=256, bias=False)
  (Layer_1_hg): Linear(in_features=256, out_features=256, bias=False)
  (Layer_1_dropout): Dropout(p=0, inplace=False)
  (Layer_2_xz): Linear(in_features=256, out_features=256, bias=True)
  (Layer_2_xr): Linear(in_features=256, out_features=256, bias=True)
  (Layer_2_xg): Linear(in_features=256, out_features=256, bias=True)
  (Layer_2_hz): Linear(in_features=256, out_features=256, bias=False)
  (Layer_2_hr): Linear(in_features=256, out_features=256, bias=False)
  (Layer_2_hg): Linear(in_features=256, out_features=256, bias=False)
  (Layer_2_dropout): Dropout(p=0, inplace=False)
  (Output_layer): Linear(in_features=256, out_features=78, bias=True)
)
y.shape=torch.Size([32, 64, 78])
h.shape=torch.Size([32, 3, 256])

Generating text by sampling¶

Now that we have a model, we can implement text generation based on it. The idea is simple: At each timestep our model receives one char $x_t$ from the input sequence and outputs scores $y_t$ for what the next char should be. We'll convert these scores into a probability over each of the possible chars. In other words, for each input char $x_t$ we create a probability distribution for the next char conditioned on the current one and the state of the model (representing all previous inputs): $$p(x_{t+1}|x_t, \vec{h}_t).$$

Once we have such a distribution, we'll sample a char from it. This will be the first char of our generated sequence. Now we can feed this new char into the model, create another distribution, sample the next char and so on. Note that it's crucial to propagate the hidden state when sampling.

The important point however is how to create the distribution from the scores. One way, as we saw in previous ML tasks, is to use the softmax function. However, a drawback of softmax is that it can generate very diffuse (more uniform) distributions if the score values are very similar. When sampling, we would prefer to control the distributions and make them less uniform to increase the chance of sampling the char(s) with the highest scores compared to the others.

To control the variance of the distribution, a common trick is to add a hyperparameter $T$, known as the temperature to the softmax function. The class scores are simply scaled by $T$ before softmax is applied: $$ \mathrm{softmax}_T(\vec{y}) = \frac{e^{\vec{y}/T}}{\sum_k e^{y_k/T}} $$

A low $T$ will result in less uniform distributions and vice-versa.

TODO: Implement the hot_softmax() function in the hw3/charnn.py module.

In [67]:
scores = y[0,0,:].detach()
_, ax = plt.subplots(figsize=(15,5))

for t in reversed([0.3, 0.5, 1.0, 100]):
    ax.plot(charnn.hot_softmax(scores, temperature=t).cpu().numpy(), label=f'T={t}')
ax.set_xlabel('$x_{t+1}$')
ax.set_ylabel('$p(x_{t+1}|x_t)$')
ax.legend()

uniform_proba = 1/len(char_to_idx)
uniform_diff = torch.abs(charnn.hot_softmax(scores, temperature=100) - uniform_proba)
test.assertTrue(torch.all(uniform_diff < 1e-4))

TODO: Implement the generate_from_model() function in the hw3/charnn.py module.

In [68]:
for _ in range(3):
    text = charnn.generate_from_model(model, "foobar", 50, (char_to_idx, idx_to_char), T=0.5)
    print(text)
    test.assertEqual(len(text), 50)
foobarNXnWT3M3-4R[Ha.D'KlZ8l'h96sahFb&NwG5zt:MqM-w
foobarSh7NuDY zLozpxU:SJI5qA9[ffjWk0kx1bUj4KLoM.Ln
foobarl9qNsE8&e3t.-'03K:9RN]Kb[ 1CktB98;Yd)L-YoQlm

Training¶

To train this model, we'll calculate the loss at each time step by comparing the predicted char to the actual char from our label. We can use cross entropy since per char it's similar to a classification problem. We'll then sum the losses over the sequence and back-propagate the gradients though time. Notice that the back-propagation algorithm will "visit" each layer's parameter tensors multiple times, so we'll accumulate gradients in parameters of the blocks. Luckily autograd will handle this part for us.

As usual, the first step of training will be to try and overfit a large model (many parameters) to a tiny dataset. Again, this is to ensure the model and training code are implemented correctly, i.e. that the model can learn.

For a generative model such as this, overfitting is slightly trickier than for classification. What we'll aim to do is to get our model to memorize a specific sequence of chars, so that when given the first char in the sequence it will immediately spit out the rest of the sequence verbatim.

Let's create a tiny dataset to memorize.

In [69]:
# Pick a tiny subset of the dataset
subset_start, subset_end = 1001, 1005
ds_corpus_ss = torch.utils.data.Subset(ds_corpus, range(subset_start, subset_end))
batch_size_ss = 1
sampler_ss = SequenceBatchSampler(ds_corpus_ss, batch_size=batch_size_ss)
dl_corpus_ss = torch.utils.data.DataLoader(ds_corpus_ss, batch_size_ss, sampler=sampler_ss, shuffle=False)

# Convert subset to text
subset_text = ''
for i in range(subset_end - subset_start):
    subset_text += unembed(ds_corpus_ss[i][0])
print(f'Text to "memorize":\n\n{subset_text}')
Text to "memorize":

TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

Now let's implement the first part of our training code.

TODO: Implement the train_epoch() and train_batch() methods of the RNNTrainer class in the hw3/training.py module. You must think about how to correctly handle the hidden state of the model between batches and epochs for this specific task (i.e. text generation).

In [70]:
import torch.nn as nn
import torch.optim as optim
from hw3.training import RNNTrainer

torch.manual_seed(42)

lr = 0.01
num_epochs = 500

in_dim = vocab_len
h_dim = 128
n_layers = 2
loss_fn = nn.CrossEntropyLoss()
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
trainer = RNNTrainer(model, loss_fn, optimizer, device)

for epoch in range(num_epochs):
    epoch_result = trainer.train_epoch(dl_corpus_ss, verbose=False)
    
    # Every X epochs, we'll generate a sequence starting from the first char in the first sequence
    # to visualize how/if/what the model is learning.
    if epoch == 0 or (epoch+1) % 25 == 0:
        avg_loss = np.mean(epoch_result.losses)
        accuracy = np.mean(epoch_result.accuracy)
        print(f'\nEpoch #{epoch+1}: Avg. loss = {avg_loss:.3f}, Accuracy = {accuracy:.2f}%')
        
        generated_sequence = charnn.generate_from_model(model, subset_text[0],
                                                        seq_len*(subset_end-subset_start),
                                                        (char_to_idx,idx_to_char), T=0.1)
        
        # Stop if we've successfully memorized the small dataset.
        print(generated_sequence)
        if generated_sequence == subset_text:
            break

# Test successful overfitting
test.assertGreater(epoch_result.accuracy, 99)
test.assertEqual(generated_sequence, subset_text)
Epoch #1: Avg. loss = 3.843, Accuracy = 17.58%
Tdt                                                                                                                                                                                                                                                             

Epoch #25: Avg. loss = 0.034, Accuracy = 99.61%
TRAM. What would you have?
  HELENA. Something; and scarce so much; nothing, indeed.
    I would not tell you what I would, my lord.
    Faith, yes:
    Strangers and foes do sunder and not kiss.
  BERTRAM. I pray you, stay not, but in haste to horse.
  HE

OK, so training works - we can memorize a short sequence. We'll now train a much larger model on our large dataset. You'll need a GPU for this part.

First, lets set up our dataset and models for training. We'll split our corpus into 90% train and 10% test-set. Also, we'll use a learning-rate scheduler to control the learning rate during training.

TODO: Set the hyperparameters in the part1_rnn_hyperparams() function of the hw3/answers.py module.

In [71]:
from hw3.answers import part1_rnn_hyperparams

hp = part1_rnn_hyperparams()
print('hyperparams:\n', hp)

### Dataset definition
vocab_len = len(char_to_idx)
batch_size = hp['batch_size']
seq_len = hp['seq_len']
train_test_ratio = 0.9
num_samples = (len(corpus) - 1) // seq_len
num_train = int(train_test_ratio * num_samples)

samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)

ds_train = torch.utils.data.TensorDataset(samples[:num_train], labels[:num_train])
sampler_train = SequenceBatchSampler(ds_train, batch_size)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size, shuffle=False, sampler=sampler_train, drop_last=True)

ds_test = torch.utils.data.TensorDataset(samples[num_train:], labels[num_train:])
sampler_test = SequenceBatchSampler(ds_test, batch_size)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size, shuffle=False, sampler=sampler_test, drop_last=True)

print(f'Train: {len(dl_train):3d} batches, {len(dl_train)*batch_size*seq_len:7d} chars')
print(f'Test:  {len(dl_test):3d} batches, {len(dl_test)*batch_size*seq_len:7d} chars')

### Training definition
in_dim = out_dim = vocab_len
checkpoint_file = 'checkpoints/rnn'
num_epochs = 50
early_stopping = 5

model = charnn.MultilayerGRU(in_dim, hp['h_dim'], out_dim, hp['n_layers'], hp['dropout'])
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=hp['learn_rate'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=hp['lr_sched_factor'], patience=hp['lr_sched_patience'], verbose=True
)
trainer = RNNTrainer(model, loss_fn, optimizer, device)
hyperparams:
 {'batch_size': 250, 'seq_len': 100, 'h_dim': 250, 'n_layers': 2, 'dropout': 0.01, 'learn_rate': 0.005, 'lr_sched_factor': 0.05, 'lr_sched_patience': 1}
Train: 228 batches, 5700000 chars
Test:   25 batches,  625000 chars

The code blocks below will train the model and save checkpoints containing the training state and the best model parameters to a file. This allows you to stop training and resume it later from where you left.

Note that you can use the main.py script provided within the assignment folder to run this notebook from the command line as if it were a python script by using the run-nb subcommand. This allows you to train your model using this notebook without starting jupyter. You can combine this with srun or sbatch to run the notebook with a GPU on the course servers.

TODO:

  • Implement the fit() method of the Trainer class. You can reuse the relevant implementation parts from HW2, but make sure to implement early stopping and checkpoints.
  • Implement the test_epoch() and test_batch() methods of the RNNTrainer class in the hw3/training.py module.
  • Run the following block to train.
  • When training is done and you're satisfied with the model's outputs, rename the checkpoint file to checkpoints/rnn_final.pt. This will cause the block to skip training and instead load your saved model when running the homework submission script. Note that your submission zip file will not include the checkpoint file. This is OK.
In [72]:
from cs236781.plot import plot_fit

def post_epoch_fn(epoch, train_res, test_res, verbose):
    # Update learning rate
    scheduler.step(test_res.accuracy)
    # Sample from model to show progress
    if verbose:
        start_seq = "ACT I."
        generated_sequence = charnn.generate_from_model(
            model, start_seq, 100, (char_to_idx,idx_to_char), T=0.5
        )
        print(generated_sequence)

# Train, unless final checkpoint is found
checkpoint_file_final = f'{checkpoint_file}_final.pt'
if os.path.isfile(checkpoint_file_final):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    saved_state = torch.load(checkpoint_file_final, map_location=device)
    model.load_state_dict(saved_state['model_state'])
else:
    try:
        # Print pre-training sampling
        print(charnn.generate_from_model(model, "ACT I.", 100, (char_to_idx,idx_to_char), T=0.5))

        fit_res = trainer.fit(dl_train, dl_test, num_epochs, max_batches=None,
                              post_epoch_fn=post_epoch_fn, early_stopping=early_stopping,
                              checkpoints=checkpoint_file, print_every=1)
        
        fig, axes = plot_fit(fit_res)
    except KeyboardInterrupt as e:
        print('\n *** Training interrupted by user')
ACT I.w?cz
iX?'E64CVELCDz,Yv?K?f"F),)jS HyCm0M3CcX?]LtA?Tr[s,ELIz,Fk8:IvW)?"pIJ1?8??)b!ATqbj2hYC8cd?
--- EPOCH 1/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 1
ACT I.
  PRINCE. I will be not have mether, that the stand the fine
    That make me for our singero
--- EPOCH 2/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 2
ACT I.
    And shall be in him be but him will be they from
    the song of the more commontry of th
--- EPOCH 3/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 3
ACT I.                                                                                              
--- EPOCH 4/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 4
ACT I. He should be gone.
                                                                          
--- EPOCH 5/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 5
ACT I.
  Ham. I have the stand be in my brother, and make your persons.
    The man hath been a very
--- EPOCH 6/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 6
ACT I.
  Leon. Make me good morrow, good night; I will not be a brave fingers.  
  Most will. I will
--- EPOCH 7/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 7
ACT I. I say the King is now
    and every truth, thy soul the great particular man's son's person,

--- EPOCH 8/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 8
ACT I. There's now the soul
    be commended to show the subject of the wars of your spirits,
    an
--- EPOCH 9/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 9
ACT I. Give me the truth,
    And be of grace with such a house of the world,
    When she is but th
--- EPOCH 10/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 10
ACT I. Well, we are more than the
    good will be necessary of the humour of the music with his sha
--- EPOCH 11/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 11
ACT I.
  Bora. O, my wife hath nothing live and bear a fear
    his blood than my mistress' charge.

--- EPOCH 12/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 12
ACT I.
  Bene. Sir, she shall see them slain.
  Bene. If thou be shouldst thou had been a most princ
--- EPOCH 13/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 13
ACT I.
  Ham. I will return the world when he of these thoughts.
  Hor. I speak with thee to me.
  H
--- EPOCH 14/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 14
ACT I.
  Bene. I know the more will I shall be but the rest of the stubborn.
                       
--- EPOCH 15/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 15
ACT I.
  Hot. I will not well for you. I have done the way
    I would not have the mean to the prop
--- EPOCH 16/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 16
Epoch    16: reducing learning rate of group 0 to 2.5000e-04.
ACT I.
  PORTIA. I have not not so much to follow my body.
    But, what most heavenly part of the d
--- EPOCH 17/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 17
ACT I.
  SIR TOBY. And the most contrary knows the cutternest man, and it was
    as little in the n
--- EPOCH 18/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 18
ACT I.                   Exit.
                                                                     
--- EPOCH 19/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 19
ACT I. The fairest corn
    I was more than the head, and the one of the sword
    To seek their hea
--- EPOCH 20/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 20
ACT I.
  DUKE SENIOR. Then were the time of this way speaks to any
    And then to the gods for the 
--- EPOCH 21/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 21
ACT I.  [To Borachio.]
  Fal. Why, then a man have the world of the rock of thy sight.
  Fal. What a
--- EPOCH 22/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 22
ACT I.
    The greatest strength of this great calpable
    That I may stand in thee for that with h
--- EPOCH 23/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 23
ACT I.
  LAFEU. And then he was a king of many true love mine honour.
    I mean, the which she hath
--- EPOCH 24/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 24
ACT I.
    The more that was the grave of the wars of him.
  DUKE. A man of good action, and poor th
--- EPOCH 25/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 25
ACT I.
                                                                                Exit [Drum an
--- EPOCH 26/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 26
ACT I.
    O Caesar! Why, the world is distracted to him.
    What say you to you and a flower?
  CA
--- EPOCH 27/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 27
ACT I.
    I have done me with the like deserves of this.
                                          
--- EPOCH 28/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 28
ACT I.
  CADE. I am a mile as they say; and so I did not so heart.
                                 
--- EPOCH 29/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 29
ACT I.
  COSTARD. This is the world mad man.
  PANDARUS. What is it thus far for the name of Sicilia
--- EPOCH 30/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 30
ACT I.
    The King, what name is the rest of this place?
    What then? What news are grown on them
--- EPOCH 31/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 31
ACT I.
  Mer. The other for a white world is like to be the man
    and the land of a head of form. 
--- EPOCH 32/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 32
ACT I.
                                                         Exit.
                              
--- EPOCH 33/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 33
ACT I.
  Prince. I shall have the skirt and letters of the sound of the
    gentleman of a stranger 
--- EPOCH 34/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 34
ACT I.                                 Exit.

SCENE II.
The carping of the Emperor

Enter HOSTESS an
--- EPOCH 35/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 35
ACT I.
  CAPHIS. I told you there! I have put on him that he would
    do not be his own lady. Have 
--- EPOCH 36/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 36
ACT I.
  GONZALO. He hath a thing is the great prince.
  VIOLA. I have seen the breaking of the fait
--- EPOCH 37/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 37
ACT I.
  Leon. So do I think the Duke of England, rot a word.
    I would thou speak the point of su
--- EPOCH 38/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 38
ACT I.
    What was he presently?
  CAPTAIN. The Duke of Norfolk, if thou shalt not see him;
    I w
--- EPOCH 39/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 39
ACT I.
  CLOWN. He is my lord's in the devil to the King.
  CRESSIDA. I will not know the house when
--- EPOCH 40/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 40
Epoch    40: reducing learning rate of group 0 to 1.2500e-05.
ACT I.
    What, what a fat fool speaks to me?
    The heavens did solicit you, sir.
    You are a c
--- EPOCH 41/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 41
ACT I.
    I shall have some mind that I am too much enough.
    I will not take the prince excuse o
--- EPOCH 42/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 42
ACT I.
    What says the meaning of the court? There's some further
    device?
  Leon. What's the m
--- EPOCH 43/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 43
ACT I.  
  CADE. A man as I have no common single princely part of
    the prince. What then?
  SHAL
--- EPOCH 44/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 44
Epoch    44: reducing learning rate of group 0 to 6.2500e-07.
ACT I.
                                                                      Exeunt.

Scene IV.
Lesi
--- EPOCH 45/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 45
ACT I.
    The fairies was a time to be the word.
  ACHILLES. What think'st thou, my lord?
  MARCUS.
--- EPOCH 46/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 46
Epoch    46: reducing learning rate of group 0 to 3.1250e-08.
ACT I.
  CRESSIDA. He that doth with the parts of his and least behaviour.
    I do beseech you all 
--- EPOCH 47/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 47
ACT I.
  COSTARD. I will provide thee not a little.
  PANDARUS. Good morrow, sir, the better.
  PARO
--- EPOCH 48/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 48
Epoch    48: reducing learning rate of group 0 to 1.5625e-09.
ACT I.
    Come, come, let's go.
                                                                   
--- EPOCH 49/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 49
ACT I.
    What, how shall I be commanded?
  GONZALO. What are you?
  PETRUCHIO. What say'st thou, T
--- EPOCH 50/50 ---
train_batch:   0%|          | 0/228 [00:00<?, ?it/s]
test_batch:   0%|          | 0/25 [00:00<?, ?it/s]

Generating a work of art¶

Armed with our fully trained model, let's generate the next Hamlet! You should experiment with modifying the sampling temperature and see what happens.

The text you generate should “look” like a Shakespeare play: old-style English words and sentence structure, directions for the actors (like “Exit/Enter”), sections (Act I/Scene III) etc. There will be no coherent plot of course, but it should at least seem like a Shakespearean play when not looking too closely. If this is not what you see, go back, debug and/or and re-train.

TODO: Specify the generation parameters in the part1_generation_params() function within the hw3/answers.py module.

In [73]:
from hw3.answers import part1_generation_params

start_seq, temperature = part1_generation_params()

generated_sequence = charnn.generate_from_model(
    model, start_seq, 10000, (char_to_idx,idx_to_char), T=temperature
)

print(generated_sequence)
When she was just a girly that
      her hands and three of my life, and such a devil. I should not say
    the hand of all the shooting state when they have found
      to the blood. I am a good love, where he is not their voices.
                                                             Exeunt.

Scene II.
A mother of the house of Gloucester's castle.

Enter Polonius, and Shallow. Exit [Laurence] and Cordelia, and Fortune's command].
                                                Enter Montano, and Saint Lear,
                                                                                                      Enter Cordelia.

  Greg. Be thou the fall of heaven!
                                                             Exeunt.

Scene III.
A court and Bardolph.

Enter Paris to the Parloguus, and Lear and Claudio.

                                Enter Kent.

  Osw. What is your will?
     I will not hear the body of his service.
                                                                                       Exeunt.

Scene II.
A hollow within the King's nephew so strange.

Enter Don Pedro and Claudio.

  Fal. I have a learned writ o' th' market-place, and thou wert
    the form of her three, out of the world.
  Prince. Why, she lies one that stands upon the court of the trumpet.
                                                           [Exit Peter.]
    For then, my lord, and then I have made a prince
    and be your father's foot to be the lady that have made
    your company of many with all the world is she in the book
    of the soul, and so are the world of the way to starve and honest man.
  Beat. No.
  Ham. I will not be found and wise, and so well the poor roaring of a
    glorious end.                                                                    Exeunt.

Scene III.
Another part of the field and the King

  Pol. Where have I so?
  Fal. Go to! Marry, my lord, my lord.
  Prince. Who is the matter of our part?
  Ham. So did I remember thee again. I thank you and my lord to the act.
  Beat. I would prove the water-tomb bound to me be the services of the
    court of the bottom of the manner of the best of the maids of
    the sun with two sons, and his hands with the gates of his
    villain. I have prizes me that the truth is dead, and let him see the  
    hand, and so be one. I am sure I believe the best of the present of
    the poor opinions. Farewell. I will be with me to her.
  Prince. Well, let me see my lord.
  Peto. What is the matter?
  Bene. What says he?
  Pedro. Good Master Constable, and the more that I have said with the
    strength of the mouth. He hath been consul, that thou art
    so perfected and hath held my love; and he is to be my
    reply.
  Pedro. What a perfecter stand the man, you shall see you?
  Bene. I do not draw the love of your command.
  Beat. I will not answer to him. I have a bound and so hated a
    great man.
  Pedro. Why, then I am sure so swear thee.
  Pedro. That is a good angel that I had the better than the
    virtue of the contrary. What is your will?
  Pedro. I will not prove upon you.
  Bene. They say the time is the same instrument of the street in his
    court. The fool will let them prove a man as they say; and there
    was not a man when the gates of this same sun that she that
    seem to be so valiant in the banks of my lord and my will.
  Petruchio. He is very well said.
  Prince. From the King, my lord, this day will be a man he will be the
    hand.
  Fal. I am sworn attendants, and the fair and the letter in the streets of
    hard soldiers that they are sent for that we heard her beholding
    worthy a father.
  Prince. Here is the basket in his house. I will not speak with the book
    of the town before the lady.
  Prince. What, have you a great strife I would have me call your worship in this
    sovereign?
  Prince. I have done to the Senate have I should not shake him to you. I
    will have the witness of it. If thou be a thousand thanks to
    have my soul to the first for the throat, and the secrets of
    it will be the day with the gods before the field and finds the stream
    more than the most accusation of my mother, and he cannot
    be old and come to be bound to my will.
  Pedro. Then the world is more than the senators of the hands
    of the three-pound. So do you so hard than a natural saint in the
    wife, and the streets of the best of the worst that will have the
    camp upon our proper sport, and the manner of the maids of his
    discovery that is the gates of the master-souls of our means to  
    the gods to counterfeit the army of the common power
    and see the tongue of the air. I know the table and come off a man of
    the service, and a cardinal of my bosom, they are so.
                                                            Exeunt.

Scene III.
Stanley.

Enter The Palace.

  Ham. This is the world to tell you himself and good night
    I know the matter of the gates of thy face.
    For that the speediest soldiers hath a thing
    To such a service to his own desires,
    And bear me up and fighting to desire.
    And there is this submission of his bosom,
    I am a man of heaven in the remembrance
    And broke our father's flame to travel me.
    Madam, I will not speak with me at friend.
  Rom. So much in me a stranger with the bosom
    With second time with a bark to the cause.
    And to each other shall make haste before;
    And what there is no matter for this court?
    Who is not what thou dost before me?
    The bridegroom of his death to me into the state,
    With fair and rash and dangerous confession.
    What means this to a death and poor and faces?
  Ben. The part of the strong isle sits in his part,
    And in the boy of Him that should have seen.
    The same of them that is a word or stark.
    I love thee in a stranger, that the world
    Were such a better summer lives in heaven.
    Romeo will be the better that were there,
    And I will find a mountain town and letters
    The sun with thoughts of such a sea to thee.
    I see the court I took the ground.
    I hope the letter, if thou wert a bloody
    To make thee like a father's wife with this,
    To see them call me truth. The realm so can
    That they desire to see the world in state
    But that the sun are then have found my presence.
    Let us not be a prince's body and fair earth,
    And this the bed that men have been to read
    The first of their officers with my country,
    And will she say the news, and then a woman's death.
    My son is sound and straight from them to see.
    Without a power to be my heart before,
    And then thou hast the fairest prosperous gods,
    And God help the other of the devil.
    What consequence, the foul soul of the crown?
    I think some constant state and life to say 'I was,
    Did then the belly book on him at home,
    Which they that show'd his father with their blood,
    As this should stay with all the rage of war.
    All loves and good contempt, stand as the senators,
    The sun of your brother being here a father
    That it is now a child shall think the King.
  Alb. Set thy hands and shallow it that I shall shake.
                                                                   [Exit Tom.
                                                     Exeunt.

Scene II.
Sorversete of the castle.

Enter Messengers, Angelo, and Beatrice.

  Friar. Let no more be the lady of the King, and so conduct
    the great company.
  Ham. I will do it.
  Pol. I know not, good my lord.
  Bene. I would the fine hand here to be done.
  Ham. Ay, for he hath forgot the man of the cause of charity!
    What, will you come to thy charge? He is a woman so.
    God save you, sir, he hath his beauty to be bound.
  Ham. Ay, and see it in my land.
                                                                      Exeunt.

Scene III.
Elsinore. A street

Enter Edgar.

  Lear. And there is no man so betrayed as the state
      With the contents of the commons of the King,
                     That she shall be so far.
                                      Exeunt.

Scene II.
Elsinore. A street

Enter Antonio and Soldiers, and others doth the King.

  King. Why, then I shall be so best love.
  Fran. Go tall what we conceive the name of this fac'd
    honourable and destruction. This man is a punishment,
    so in this shepherd, and to stand and fight at my own soul
    Is not so true as they are set down a care.
  Ham. No, thou art angry.
  Hot. The same state shall be so, he was not like a man.
                                                                                Exit.
  Gon. My wife were nothing entertainment, but I
     than merry than the gods are at his head. If thou be fair, and
     should be so well as much as the sun will make the compass of the
     orchard for the door with the chain of his accusations. I must confess
     the law to strike the late fingers of the particular, and the King's poor
     three or five of mine own proper man.
  Edm. I know not what I have made the word of the proof.
  Lear. What's the man of the part? Why, the best and such a thousand phease
     that he is a rover- for a story of a devil and the dog
     that they say.
  Osr. You are a man.
  Claud. I know not. I will be in his writing to my wife and make it off
    the world in the streets. I am a beast of it in the sun that you shall
    think this will see you and the priest of the maids of a fair
    country.
  Pedro. I am glad to speak with you. If I love you, sir; so I do not know
    the sea to the city.
  Ham. I would I do not say the gods stand for that your part.
  Ham. What can you proceed in her father.
  Bene. This is not with a world in the matter.
  Prince. Are you not a poor and the truth with the least of the town? What
    thing is the time of his country are the compounded with the robb'd court?
    Good Master Gower, you shine!
  Prince. I am a man, an'

Questions¶

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [1]:
from cs236781.answers import display_answer
import hw3.answers

Question 1¶

Why do we split the corpus into sequences instead of training on the whole text?

In [2]:
display_answer(hw3.answers.part1_q1)

Your answer:

    We have a large corpus, uploading the entire corpus at once onto the machine requires large memory resources, which slows down the training procss. 
    We avoid this phenomenon by breaking the corpus down into small parts and load one part at a time to the memory.

Question 2¶

How is it possible that the generated text clearly shows memory longer than the sequence length?

In [3]:
display_answer(hw3.answers.part1_q2)

Your answer: Memory in RNN's is derived by the hidden state's ability to predict the next word in the sequence itwas trained on. Our network demonstrates a longer memory capability than just the lengths of the sequences we trained it on, because the hidden state learns the interconnections between sequences and genarlizes to te entire corpus.

Question 3¶

Why are we not shuffling the order of batches when training?

In [4]:
display_answer(hw3.answers.part1_q3)

Your answer: We don't mix the order of the batches when training because we want to train the modules in the correct order. Training the modules according to the correct order ensures keeping a correct and logical relationship between the sentences. Additionally, it takes context into accont. This helps our module in generating a text which resembles the original text.

Question 4¶

  1. Why do we lower the temperature for sampling (compared to the default of $1.0$)?
  2. What happens when the temperature is very high and why?
  3. What happens when the temperature is very low and why?
In [6]:
display_answer(hw3.answers.part1_q4)

Your answer:

a. We lower the temprature for the model to make the conditional distribution of the next word givn the current one as dissimilar to uniform distribution as possible. If the distribution were indeed uniform then taking maximum argument as criterion will yield very unpredictable and thus uninformative results.

b. Probability over the output with temparature T defined as $ e^{y_i/T} / \sum{e^{y_i/T}} $ If T is very large than the exponent is very close to 0, then the numerator will be around 1 and denominator around n, then for any output we obtain a distribution similar to uniform distribution.

c. Using a very low temperature means that the variance of the distribution is also small. This means the the model would be very far from a uniform distribution. As a cosequence to that, the generated model would choose only that chars that it's certain about, without taking any risks in choosing other chars. This would yield corpus with a very constrained number of chars, becuase the other chars didn't have a chance of being picked by the model.

In [ ]:
 
In [ ]:
 
$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 2: Variational Autoencoder¶

In this part we will learn to generate new data using a special type of autoencoder model which allows us to sample from its latent space. We'll implement and train a VAE and use it to generate new images.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2
In [2]:
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda

Obtaining the dataset¶

Let's begin by downloading a dataset of images that we want to learn to generate. We'll use the Labeled Faces in the Wild (LFW) dataset which contains many labeled faces of famous individuals.

We're going to train our generative model to generate a specific face, not just any face. Since the person with the most images in this dataset is former president George W. Bush, we'll set out to train a Bush Generator :)

However, if you feel adventurous and/or prefer to generate something else, feel free to edit the PART2_CUSTOM_DATA_URL variable in hw3/answers.py.

In [3]:
import cs236781.plot as plot
import cs236781.download
from hw3.answers import PART2_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236781.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/snirhordan/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/snirhordan/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/snirhordan/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extraced images:

In [4]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [5]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(15,10), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [6]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

The Variational Autoencoder¶

An autoencoder is a model which learns a representation of data in an unsupervised fashion (i.e without any labels). Recall it's general form from the lecture:

An autoencoder maps an instance $\bb{x}$ to a latent-space representation $\bb{z}$. It has an encoder part, $\Phi_{\bb{\alpha}}(\bb{x})$ (a model with parameters $\bb{\alpha}$) and a decoder part, $\Psi_{\bb{\beta}}(\bb{z})$ (a model with parameters $\bb{\beta}$).

While autoencoders can learn useful representations, generally it's hard to use them as generative models because there's no distribution we can sample from in the latent space. In other words, we have no way to choose a point $\bb{z}$ in the latent space such that $\Psi(\bb{z})$ will end up on the data manifold in the instance space.

The variational autoencoder (VAE), first proposed by Kingma and Welling, addresses this issue by taking a probabilistic perspective. Briefly, a VAE model can be described as follows.

We define, in Baysean terminology,

  • The prior distribution $p(\bb{Z})$ on points in the latent space.
  • The posterior distribution of points in the latent spaces given a specific instance: $p(\bb{Z}|\bb{X})$.
  • The likelihood distribution of a sample $\bb{X}$ given a latent-space representation: $p(\bb{X}|\bb{Z})$.
  • The evidence distribution $p(\bb{X})$ which is the distribution of the instance space due to the generative process.

To create our variational decoder we'll further specify:

  • A parametric likelihood distribution, $p _{\bb{\beta}}(\bb{X} | \bb{Z}=\bb{z}) = \mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$. The interpretation is that given a latent $\bb{z}$, we map it to a point normally distributed around the point calculated by our decoder neural network. Note that here $\sigma^2$ is a hyperparameter while $\vec{\beta}$ represents the network parameters.
  • A fixed latent-space prior distribution of $p(\bb{Z}) = \mathcal{N}(\bb{0},\bb{I})$.

This setting allows us to generate a new instance $\bb{x}$ by sampling $\bb{z}$ from the multivariate normal distribution, obtaining the instance-space mean $\Psi _{\bb{\beta}}(\bb{z})$ using our decoder network, and then sampling $\bb{x}$ from $\mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$.

Our variational encoder will approximate the posterior with a parametric distribution $q _{\bb{\alpha}}(\bb{Z} | \bb{x}) = \mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$. The interpretation is that our encoder model, $\Phi_{\vec{\alpha}}(\bb{x})$, calculates the mean and variance of the posterior distribution, and samples $\bb{z}$ based on them. An important nuance here is that our network can't contain any stochastic elements that depend on the model parameters, otherwise we won't be able to back-propagate to those parameters. So sampling $\bb{z}$ from $\mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$ is not an option. The solution is to use what's known as the reparametrization trick: sample from an isotropic Gaussian, i.e. $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ (which doesn't depend on trainable parameters), and calculate the latent representation as $\bb{z} = \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{u}\odot\bb{\sigma}_{\bb{\alpha}}(\bb{x})$.

To train a VAE model, we maximize the evidence distribution, $p(\bb{X})$ (see question below). The VAE loss can therefore be stated as minimizing $\mathcal{L} = -\mathbb{E}_{\bb{x}} \log p(\bb{X})$. Although this expectation is intractable, we can obtain a lower-bound for $p(\bb{X})$ (the evidence lower bound, "ELBO", shown in the lecture):

$$ \log p(\bb{X}) \ge \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} }\left[ \log p _{\bb{\beta}}(\bb{X} | \bb{z}) \right] - \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{X})\,\left\|\, p(\bb{Z} )\right.\right) $$

where $ \mathcal{D} _{\mathrm{KL}}(q\left\|\right.p) = \mathbb{E}_{\bb{z}\sim q}\left[ \log \frac{q(\bb{Z})}{p(\bb{Z})} \right] $ is the Kullback-Liebler divergence, which can be interpreted as the information gained by using the posterior $q(\bb{Z|X})$ instead of the prior distribution $p(\bb{Z})$.

Using the ELBO, the VAE loss becomes, $$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} {\bb{x}} \left[ \mathbb{E} {\bb{z} \sim q {\bb{\alpha}} }\left[ -\log p {\bb{\beta}}(\bb{x} | \bb{z}) \right]

  • \mathcal{D} {\mathrm{KL}}\left(q {\bb{\alpha}}(\bb{Z} | \bb{x})\,\left|\, p(\bb{Z} )\right.\right) \right]. $$

By remembering that the likelihood is a Gaussian distribution with a diagonal covariance and by applying the reparametrization trick, we can write the above as

$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} } \left[ \frac{1}{2\sigma^2}\left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z} )\right.\right) \right]. $$

Model Implementation¶

Obviously our model will have two parts, an encoder and a decoder. Since we're working with images, we'll implement both as deep convolutional networks, where the decoder is a "mirror image" of the encoder implemented with adjoint (AKA transposed) convolutions. Between the encoder CNN and the decoder CNN we'll implement the sampling from the parametric posterior approximator $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ to make it a VAE model and not just a regular autoencoder (of course, this is not yet enough to create a VAE, since we also need a special loss function which we'll get to later).

First let's implement just the CNN part of the Encoder network (this is not the full $\Phi_{\vec{\alpha}}(\bb{x})$ yet). As usual, it should take an input image and map to a activation volume of a specified depth. We'll consider this volume as the features we extract from the input image. Later we'll use these to create the latent space representation of the input.

TODO: Implement the EncoderCNN class in the hw3/autoencoder.py module. Implement any CNN architecture you like. If you need "architecture inspiration" you can see e.g. this or this paper.

In [7]:
import hw3.autoencoder as autoencoder

in_channels = 3
out_channels = 1024
encoder_cnn = autoencoder.EncoderCNN(in_channels, out_channels).to(device)
print(encoder_cnn)

h = encoder_cnn(x0)
print(h.shape)

test.assertEqual(h.dim(), 4)
test.assertSequenceEqual(h.shape[0:2], (1, out_channels))
EncoderCNN(
  (cnn): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(128, 512, kernel_size=(5, 5), stride=(2, 2))
    (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
    (10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
  )
)
torch.Size([1, 1024, 5, 5])

Now let's implement the CNN part of the Decoder. Again this is not yet the full $\Psi _{\bb{\beta}}(\bb{z})$. It should take an activation volume produced by your EncoderCNN and output an image of the same dimensions as the Encoder's input was. This can be a CNN which is like a "mirror image" of the the Encoder. For example, replace convolutions with transposed convolutions, downsampling with up-sampling etc. Consult the documentation of ConvTranspose2D to figure out how to reverse your convolutional layers in terms of input and output dimensions. Note that the decoder doesn't have to be exactly the opposite of the encoder and you can experiment with using a different architecture.

TODO: Implement the DecoderCNN class in the hw3/autoencoder.py module.

In [8]:
decoder_cnn = autoencoder.DecoderCNN(in_channels=out_channels, out_channels=in_channels).to(device)
print(decoder_cnn)
x0r = decoder_cnn(h)
print(x0r.shape)

test.assertEqual(x0.shape, x0r.shape)

# Should look like colored noise
T.functional.to_pil_image(x0r[0].cpu().detach())
DecoderCNN(
  (cnn): Sequential(
    (0): ConvTranspose2d(1024, 512, kernel_size=(5, 5), stride=(2, 2))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(512, 128, kernel_size=(5, 5), stride=(2, 2))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(1, 1))
    (10): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
torch.Size([1, 3, 64, 64])
Out[8]:

Let's now implement the full VAE Encoder, $\Phi_{\vec{\alpha}}(\vec{x})$. It will work as follows:

  1. Produce a feature vector $\vec{h}$ from the input image $\vec{x}$.
  2. Use two affine transforms to convert the features into the mean and log-variance of the posterior, i.e. $$ \begin{align}
     \bb{\mu} _{\bb{\alpha}}(\bb{x}) &= \vec{h}\mattr{W}_{\mathrm{h\mu}} + \vec{b}_{\mathrm{h\mu}} \\
     \log\left(\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\right) &= \vec{h}\mattr{W}_{\mathrm{h\sigma^2}} + \vec{b}_{\mathrm{h\sigma^2}}
    
    \end{align} $$
  3. Use the reparametrization trick to create the latent representation $\vec{z}$.

Notice that we model the log of the variance, not the actual variance. The above formulation is proposed in appendix C of the VAE paper.

TODO: Implement the encode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__().

In [9]:
z_dim = 2
vae = autoencoder.VAE(encoder_cnn, decoder_cnn, x0[0].size(), z_dim).to(device)
print(vae)

z, mu, log_sigma2 = vae.encode(x0)

test.assertSequenceEqual(z.shape, (1, z_dim))
test.assertTrue(z.shape == mu.shape == log_sigma2.shape)

print(f'mu(x0)={list(*mu.detach().cpu().numpy())}, sigma2(x0)={list(*torch.exp(log_sigma2).detach().cpu().numpy())}')
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(128, 512, kernel_size=(5, 5), stride=(2, 2))
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
      (10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(5, 5), stride=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvTranspose2d(512, 128, kernel_size=(5, 5), stride=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
      (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(1, 1))
      (10): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (log): Linear(in_features=25600, out_features=2, bias=True)
  (reconstruct): Linear(in_features=2, out_features=25600, bias=True)
  (mu): Linear(in_features=25600, out_features=2, bias=True)
)
mu(x0)=[-0.1058653, 0.08127165], sigma2(x0)=[1.9030428, 0.90913814]

Let's sample some 2d latent representations for an input image x0 and visualize them.

In [10]:
# Sample from q(Z|x)
N = 500
Z = torch.zeros(N, z_dim)
_, ax = plt.subplots()
with torch.no_grad():
    for i in range(N):
        Z[i], _, _ = vae.encode(x0)
        ax.scatter(*Z[i].cpu().numpy())

# Should be close to the mu/sigma in the previous block above
print('sampled mu', torch.mean(Z, dim=0))
print('sampled sigma2', torch.var(Z, dim=0))
sampled mu tensor([-0.0600,  0.0384])
sampled sigma2 tensor([3.6430, 0.7947])

Let's now implement the full VAE Decoder, $\Psi _{\bb{\beta}}(\bb{z})$. It will work as follows:

  1. Produce a feature vector $\tilde{\vec{h}}$ from the latent vector $\vec{z}$ using an affine transform.
  2. Reconstruct an image $\tilde{\vec{x}}$ from $\tilde{\vec{h}}$ using the decoder CNN.

TODO: Implement the decode() method in the VAE class within the hw3/autoencoder.py module. You'll also need to define your parameters in __init__(). You may need to also re-run the block above after you implement this.

In [11]:
x0r = vae.decode(z)

test.assertSequenceEqual(x0r.shape, x0.shape)

Our model's forward() function will simply return decode(encode(x)) as well as the calculated mean and log-variance of the posterior.

In [12]:
x0r, mu, log_sigma2 = vae(x0)

test.assertSequenceEqual(x0r.shape, x0.shape)
test.assertSequenceEqual(mu.shape, (1, z_dim))
test.assertSequenceEqual(log_sigma2.shape, (1, z_dim))
T.functional.to_pil_image(x0r[0].detach().cpu())
Out[12]:

Loss Implementation¶

In practice, since we're using SGD, we'll drop the expectation over $\bb{X}$ and instead sample an instance from the training set and compute a point-wise loss. Similarly, we'll drop the expectation over $\bb{Z}$ by sampling from $q_{\vec{\alpha}}(\bb{Z}|\bb{x})$. Additionally, because the KL divergence is between two Gaussian distributions, there is a closed-form expression for it. These points bring us to the following point-wise loss:

$$ \ell(\vec{\alpha},\vec{\beta};\bb{x}) = \frac{1}{\sigma^2 d_x} \left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 + \mathrm{tr}\,\bb{\Sigma} _{\bb{\alpha}}(\bb{x}) + \|\bb{\mu} _{\bb{\alpha}}(\bb{x})\|^2 _2 - d_z - \log\det \bb{\Sigma} _{\bb{\alpha}}(\bb{x}), $$

where $d_z$ is the dimension of the latent space, $d_x$ is the dimension of the input and $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$. This pointwise loss is the quantity that we'll compute and minimize with gradient descent. The first term corresponds to the data-reconstruction loss, while the second term corresponds to the KL-divergence loss. Note that the scaling by $d_x$ is not derived from the original loss formula and was added directly to the pointwise loss just to normalize the data term.

TODO: Implement the vae_loss() function in the hw3/autoencoder.py module.

In [13]:
from hw3.autoencoder import vae_loss
torch.manual_seed(42)

def test_vae_loss():
    # Test data
    N, C, H, W = 10, 3, 64, 64 
    z_dim = 32
    x  = torch.randn(N, C, H, W)*2 - 1
    xr = torch.randn(N, C, H, W)*2 - 1
    z_mu = torch.randn(N, z_dim)
    z_log_sigma2 = torch.randn(N, z_dim)
    x_sigma2 = 0.9
    
    loss, _, _ = vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
    
    test.assertAlmostEqual(loss.item(), 58.3234367, delta=1e-3)
    return loss

test_vae_loss()
Out[13]:
tensor(58.3234)

Sampling¶

The main advantage of a VAE is that it can by used as a generative model by sampling the latent space, since we optimize for a isotropic Gaussian prior $p(\bb{Z})$ in the loss function. Let's now implement this so that we can visualize how our model is doing when we train.

TODO: Implement the sample() method in the VAE class within the hw3/autoencoder.py module.

In [14]:
samples = vae.sample(5)
_ = plot.tensors_as_images(samples)

Training¶

Time to train!

TODO:

  1. Implement the VAETrainer class in the hw3/training.py module. Make sure to implement the checkpoints feature of the Trainer class if you haven't done so already in Part 1.
  2. Tweak the hyperparameters in the part2_vae_hyperparams() function within the hw3/answers.py module.
In [15]:
import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from hw3.training import VAETrainer
from hw3.answers import part2_vae_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part2_vae_hyperparams()
batch_size = hp['batch_size']
h_dim = hp['h_dim']
z_dim = hp['z_dim']
x_sigma2 = hp['x_sigma2']
learn_rate = hp['learn_rate']
betas = hp['betas']

# Data
split_lengths = [int(len(ds_gwb)*0.9), int(len(ds_gwb)*0.1)]
ds_train, ds_test = random_split(ds_gwb, split_lengths)
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test  = DataLoader(ds_test,  batch_size, shuffle=True)
im_size = ds_train[0][0].shape

# Model
encoder = autoencoder.EncoderCNN(in_channels=im_size[0], out_channels=h_dim)
decoder = autoencoder.DecoderCNN(in_channels=h_dim, out_channels=im_size[0])
vae = autoencoder.VAE(encoder, decoder, im_size, z_dim)
vae_dp = DataParallel(vae).to(device)

# Optimizer
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)

# Loss
def loss_fn(x, xr, z_mu, z_log_sigma2):
    return autoencoder.vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)

# Trainer
trainer = VAETrainer(vae_dp, loss_fn, optimizer, device)
checkpoint_file = 'checkpoints/vae'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show model and hypers
print(vae)
print(hp)
VAE(
  (features_encoder): EncoderCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(128, 512, kernel_size=(5, 5), stride=(2, 2))
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      (10): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
    )
  )
  (features_decoder): DecoderCNN(
    (cnn): Sequential(
      (0): ConvTranspose2d(512, 512, kernel_size=(5, 5), stride=(2, 2))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvTranspose2d(512, 128, kernel_size=(5, 5), stride=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
      (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(1, 1))
      (10): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (log): Linear(in_features=12800, out_features=256, bias=True)
  (reconstruct): Linear(in_features=256, out_features=12800, bias=True)
  (mu): Linear(in_features=12800, out_features=256, bias=True)
)
{'batch_size': 32, 'h_dim': 512, 'z_dim': 256, 'x_sigma2': 0.00095, 'learn_rate': 9e-05, 'betas': (0.99, 0.998)}

TODO:

  1. Run the following block to train. It will sample some images from your model every few epochs so you can see the progress.
  2. When you're satisfied with your results, rename the checkpoints file by adding _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training. Note that your final submission zip will not include the checkpoints/ folder. This is OK.

The images you get should be colorful, with different backgrounds and poses.

In [16]:
import IPython.display

def post_epoch_fn(epoch, train_result, test_result, verbose):
    # Plot some samples if this is a verbose epoch
    if verbose:
        samples = vae.sample(n=5)
        fig, _ = plot.tensors_as_images(samples, figsize=(6,2))
        IPython.display.display(fig)
        plt.close(fig)

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    checkpoint_file = checkpoint_file_final
else:
    res = trainer.fit(dl_train, dl_test,
                      num_epochs=200, early_stopping=20, print_every=10,
                      checkpoints=checkpoint_file,
                      post_epoch_fn=post_epoch_fn)
    
# Plot images from best model
saved_state = torch.load(f'{checkpoint_file}.pt', map_location=device)
vae_dp.load_state_dict(saved_state['model_state'])
print('*** Images Generated from best model:')
fig, _ = plot.tensors_as_images(vae_dp.module.sample(n=15), nrows=3, figsize=(6,6))
--- EPOCH 1/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 1
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 2
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 3
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 4
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 5
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 6
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 7
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 8
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 9
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 10
--- EPOCH 11/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 11
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 12
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 13
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 14
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 15
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 16
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 17
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 18
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 19
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 20
--- EPOCH 21/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 21
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 22
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 23
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 24
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 25
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 26
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 27
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 28
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 29
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 30
--- EPOCH 31/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 31
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 32
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 33
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 34
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 35
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 36
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 37
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 38
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 39
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 40
--- EPOCH 41/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 41
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 42
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 43
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 44
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 45
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 46
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 47
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 48
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 49
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 50
--- EPOCH 51/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 51
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 52
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 53
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 54
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 55
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 56
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 57
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 58
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 59
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 60
--- EPOCH 61/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 61
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 62
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 63
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 64
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 65
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 66
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 67
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 68
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 69
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 70
--- EPOCH 71/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 71
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 72
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 73
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 74
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 75
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 76
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 77
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 78
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 79
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 80
--- EPOCH 81/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 81
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 82
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 83
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 84
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 85
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 86
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 87
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 88
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 89
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 90
--- EPOCH 91/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 91
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 92
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 93
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 94
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 95
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 96
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 97
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 98
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 99
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 100
--- EPOCH 101/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 101
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 102
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 103
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 104
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 105
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 106
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 107
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 108
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 109
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 110
--- EPOCH 111/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 111
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 112
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 113
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 114
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 115
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 116
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 117
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 118
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 119
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 120
--- EPOCH 121/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 121
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 122
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 123
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 124
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 125
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 126
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 127
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 128
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 129
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 130
--- EPOCH 131/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 131
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 132
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 133
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 134
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 135
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 136
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 137
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 138
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 139
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 140
--- EPOCH 141/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 141
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 142
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 143
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 144
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 145
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 146
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 147
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 148
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 149
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 150
--- EPOCH 151/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 151
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 152
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 153
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 154
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 155
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 156
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 157
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 158
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 159
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 160
--- EPOCH 161/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 161
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 162
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 163
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 164
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 165
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 166
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 167
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 168
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 169
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 170
--- EPOCH 171/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 171
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 172
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 173
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 174
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 175
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 176
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 177
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 178
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 179
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 180
--- EPOCH 181/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 181
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 182
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 183
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 184
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 185
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 186
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 187
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 188
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 189
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 190
--- EPOCH 191/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 191
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 192
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 193
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 194
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 195
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 196
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 197
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 198
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 199
--- EPOCH 200/200 ---
train_batch:   0%|          | 0/15 [00:00<?, ?it/s]
test_batch:   0%|          | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 200
*** Images Generated from best model:

Questions¶

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [17]:
from cs236781.answers import display_answer
import hw3.answers as answers

Question 1¶

What does the $\sigma^2$ hyperparameter (x_sigma2 in the code) do? Explain the effect of low and high values.

In [18]:
display_answer(answers.part2_q1)

Your answer: The hyperparameter $\sigma^2$ is used to set the distance between the encoding and the mean (describes the allowed difference between the distance and the mean.) By using low sigma values, the images generated by the model are closer to the training data, that's because the model is closer to the mean and is more constrained by the data it has seen. This is in contrast to using high sigma values, which may produce images that differ from the learned data.

Question 2¶

  1. Explain the purpose of both parts of the VAE loss term - reconstruction loss and KL divergence loss.
  2. How is the latent-space distribution affected by the KL loss term?
  3. What's the benefit of this effect?
In [19]:
display_answer(answers.part2_q2)

Your answer: 1)Reconstruction Loss: Gives us a measure of how well the decoder reconstructs x. KL divergence loss: is a regularizer that measures how much information we lose when using q to represent p.

2)The effect of the KL loss on the latent-space distribution is as follows: the KL loss changes z_mu and z_sigma_2 given an instance of x by penalising the model to an inferior distribution of z.

3) The benefit of this effect lies in the improvement of the generation task, because it adds interpolations between classes and remove dicontinuities in the latent-space.

Question 3¶

In the formulation of the VAE loss, why do we start by maximizing the evidence distribution, $p(\bb{X})$?

In [20]:
display_answer(answers.part2_q3)

Your answer: In the formulation of the VAE loss, we start by maximizing the evidence distribution, $p(\bb{X})$ because this helps us in finding the probability distrubuion of the data. This means that maximizing $p(\bb{X})$ gives a propper aproximation of the actual distribuation of the data.

Question 4¶

In the VAE encoder, why do we model the log of the latent-space variance corresponding to an input, $\bb{\sigma}^2_{\bb{\alpha}}$, instead of directly modelling this variance?

In [21]:
display_answer(answers.part2_q4)

Your answer: We use the log function here because we want to change the problem from a multiplication of all the the probabilities to a summation of the lof of all of those probabilities. We can use the log because it is monitinically ascending, and so the maximal value won't change. We can assume that because each data we got to train the model, is sampeld by the actual distibution.

$$ \newcommand{\mat}[1]{\boldsymbol {#1}} \newcommand{\mattr}[1]{\boldsymbol {#1}^\top} \newcommand{\matinv}[1]{\boldsymbol {#1}^{-1}} \newcommand{\vec}[1]{\boldsymbol {#1}} \newcommand{\vectr}[1]{\boldsymbol {#1}^\top} \newcommand{\rvar}[1]{\mathrm {#1}} \newcommand{\rvec}[1]{\boldsymbol{\mathrm{#1}}} \newcommand{\diag}{\mathop{\mathrm {diag}}} \newcommand{\set}[1]{\mathbb {#1}} \newcommand{\norm}[1]{\left\lVert#1\right\rVert} \newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}} \newcommand{\bm}[1]{{\bf #1}} \newcommand{\bb}[1]{\bm{\mathrm{#1}}} $$

Part 3: Generative Adversarial Networks¶

In this part we will implement and train a generative adversarial network and apply it to the task of image generation.

In [1]:
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile

import numpy as np
import torch
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cpu

Obtaining the dataset¶

We'll use the same data as in Part 2.

But again, you can use a custom dataset, by editing the PART3_CUSTOM_DATA_URL variable in hw3/answers.py.

In [2]:
import cs236781.plot as plot
import cs236781.download
from hw3.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL

DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
    DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
    DATA_URL = CUSTOM_DATA_URL

_, dataset_dir = cs236781.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/kali/.pytorch-datasets/lfw-bush.zip exists, skipping download.
Extracting /home/kali/.pytorch-datasets/lfw-bush.zip...
Extracted 531 to /home/kali/.pytorch-datasets/lfw/George_W_Bush

Create a Dataset object that will load the extraced images:

In [3]:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder

im_size = 64
tf = T.Compose([
    # Resize to constant spatial dimensions
    T.Resize((im_size, im_size)),
    # PIL.Image -> torch.Tensor
    T.ToTensor(),
    # Dynamic range [0,1] -> [-1, 1]
    T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])

ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)

OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.

In [4]:
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(15,10), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
In [5]:
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)

test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])

Generative Adversarial Nets (GANs)¶

GANs, first proposed in a paper by Ian Goodfellow in 2014 are today arguably the most popular type of generative model. GANs are currently producing state of the art results in generative tasks over many different domains.

In a GAN model, two different neural networks compete against each other: A generator and a discriminator.

  • The Generator, which we'll denote as $\Psi _{\bb{\gamma}} : \mathcal{U} \rightarrow \mathcal{X}$, maps a latent-space variable $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ to an instance-space variable $\bb{x}$ (e.g. an image). Thus a parametric evidence distribution $p_{\bb{\gamma}}(\bb{X})$ is generated, which we typically would like to be as close as possible to the real evidence distribution, $p(\bb{X})$.

  • The Discriminator, $\Delta _{\bb{\delta}} : \mathcal{X} \rightarrow [0,1]$, is a network which, given an instance-space variable $\bb{x}$, returns the probability that $\bb{x}$ is real, i.e. that $\bb{x}$ was sampled from $p(\bb{X})$ and not $p_{\bb{\gamma}}(\bb{X})$.

Training GANs¶

The generator is trained to generate "fake" instances which will maximally fool the discriminator into returning that they're real. Mathematically, the generator's parameters $\bb{\gamma}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

The discriminator is trained to classify between real images, coming from the training set, and fake images generated by the generator. Mathematically, the discriminator's parameters $\bb{\delta}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

These two competing objectives can thus be expressed as the following min-max optimization: $$ \min _{\bb{\gamma}} \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

A key insight into GANs is that we can interpret the above maximum as the loss with respect to $\bb{\gamma}$:

$$ L({\bb{\gamma}}) = \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

This means that the generator's loss function trains together with the generator itself in an adversarial manner. In contrast, when training our VAE we used a fixed L2 norm as a data loss term.

Model Implementation¶

We'll now implement a Deep Convolutional GAN (DCGAN) model. See the DCGAN paper for architecture ideas and tips for training.

TODO: Implement the Discriminator class in the hw3/gan.py module. If you wish you can reuse the EncoderCNN class from the VAE model as the first part of the Discriminator.

In [6]:
import hw3.gan as gan

dsc = gan.Discriminator(in_size=x0[0].shape).to(device)
print(dsc)

d0 = dsc(x0)
print(d0.shape)

test.assertSequenceEqual(d0.shape, (1,1))
Discriminator(
  (dsc): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
  )
)
torch.Size([1, 1])

TODO: Implement the Generator class in the hw3/gan.py module. If you wish you can reuse the DecoderCNN class from the VAE model as the last part of the Generator.

In [7]:
z_dim = 128
gen = gan.Generator(z_dim, 4).to(device)
print(gen)

z = torch.randn(1, z_dim).to(device)
xr = gen(z)
print(xr.shape)

test.assertSequenceEqual(x0.shape, xr.shape)
Generator(
  (cnn): Sequential(
    (0): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Tanh()
  )
)
torch.Size([1, 3, 64, 64])

Loss Implementation¶

Let's begin with the discriminator's loss function. Based on the above we can flip the sign and say we want to update the Discriminator's parameters $\bb{\delta}$ so that they minimize the expression $$

  • \mathbb{E} {\bb{x} \sim p(\bb{X}) } \log \Delta {\bb{\delta}}(\bb{x}) \, - \, \mathbb{E} {\bb{z} \sim p(\bb{Z}) } \log (1-\Delta {\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$

We're using the Discriminator twice in this expression; once to classify data from the real data distribution and once again to classify generated data. Therefore our loss should be computed based on these two terms. Notice that since the discriminator returns a probability, we can formulate the above as two cross-entropy losses.

GANs are notoriously diffucult to train. One common trick for improving GAN stability during training is to make the classification labels noisy for the discriminator. This can be seen as a form of regularization, to help prevent the discriminator from overfitting.

We'll incorporate this idea into our loss function. Instead of labels being equal to 0 or 1, we'll make them "fuzzy", i.e. random numbers in the ranges $[0\pm\epsilon]$ and $[1\pm\epsilon]$.

TODO: Implement the discriminator_loss_fn() function in the hw3/gan.py module.

In [8]:
from hw3.gan import discriminator_loss_fn
torch.manual_seed(42)

y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10

loss = discriminator_loss_fn(y_data, y_generated, data_label=1, label_noise=0.3)
print(loss)

test.assertAlmostEqual(loss.item(), 6.4808731, delta=1e-5)
tensor(6.4809)

Similarly, the generator's parameters $\bb{\gamma}$ should minimize the expression $$ -\mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )) $$

which can also be seen as a cross-entropy term. This corresponds to "fooling" the discriminator; Notice that the gradient of the loss w.r.t $\bb{\gamma}$ using this expression also depends on $\bb{\delta}$.

TODO: Implement the generator_loss_fn() function in the hw3/gan.py module.

In [9]:
from hw3.gan import generator_loss_fn
torch.manual_seed(42)

y_generated = torch.rand(20) * 10

loss = generator_loss_fn(y_generated, data_label=1)
print(loss)

test.assertAlmostEqual(loss.item(), 0.0222969, delta=1e-3)
tensor(0.0223)

Sampling¶

Sampling from a GAN is straightforward, since it learns to generate data from an isotropic Gaussian latent space distribution.

There is an important nuance however. Sampling is required during the process of training the GAN, since we generate fake images to show the discriminator. As you'll seen in the next section, in some cases we'll need our samples to have gradients (i.e., to be part of the Generator's computation graph).

TODO: Implement the sample() method in the Generator class within the hw3/gan.py module.

In [10]:
samples = gen.sample(5, with_grad=False)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNone(samples.grad_fn)
_ = plot.tensors_as_images(samples.cpu())

samples = gen.sample(5, with_grad=True)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNotNone(samples.grad_fn)

Training¶

Training GANs is a bit different since we need to train two models simultaneously, each with it's own separate loss function and optimizer. We'll implement the training logic as a function that handles one batch of data and updates both the discriminator and the generator based on it.

As mentioned above, GANs are considered hard to train. To get some ideas and tips you can see this paper, this list of "GAN hacks" or just do it the hard way :)

TODO:

  1. Implement the train_batch function in the hw3/gan.py module.
  2. Tweak the hyperparameters in the part3_gan_hyperparams() function within the hw3/answers.py module.
In [11]:
import torch.optim as optim
from torch.utils.data import DataLoader
from hw3.answers import part3_gan_hyperparams

torch.manual_seed(42)

# Hyperparams
hp = part3_gan_hyperparams()
batch_size = hp['batch_size']
z_dim = hp['z_dim']

# Data
dl_train = DataLoader(ds_gwb, batch_size, shuffle=True)
im_size = ds_gwb[0][0].shape

# Model
dsc = gan.Discriminator(im_size).to(device)
gen = gan.Generator(z_dim, featuremap_size=4).to(device)

# Optimizer
def create_optimizer(model_params, opt_params):
    opt_params = opt_params.copy()
    optimizer_type = opt_params['type']
    opt_params.pop('type')
    return optim.__dict__[optimizer_type](model_params, **opt_params)
dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])

# Loss
def dsc_loss_fn(y_data, y_generated):
    return gan.discriminator_loss_fn(y_data, y_generated, hp['data_label'], hp['label_noise'])

def gen_loss_fn(y_generated):
    return gan.generator_loss_fn(y_generated, hp['data_label'])

# Training
checkpoint_file = 'checkpoints/gan'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
    os.remove(f'{checkpoint_file}.pt')

# Show hypers
print(hp)
{'batch_size': 8, 'z_dim': 100, 'data_label': 1, 'label_noise': 0.2, 'discriminator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.5, 0.999)}, 'generator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.5, 0.999)}}

TODO:

  1. Implement the save_checkpoint function in the hw3.gan module. You can decide on your own criterion regarding whether to save a checkpoint at the end of each epoch.
  2. Run the following block to train. It will sample some images from your model every few epochs so you can see the progress.
  3. When you're satisfied with your results, rename the checkpoints file by adding _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training. Note that your final submission zip will not include the checkpoints/ folder. This is OK.
In [12]:
import IPython.display
import tqdm
from hw3.gan import train_batch, save_checkpoint

num_epochs = 100

if os.path.isfile(f'{checkpoint_file_final}.pt'):
    print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
    num_epochs = 0
    gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device,)
    checkpoint_file = checkpoint_file_final

try:
    dsc_avg_losses, gen_avg_losses = [], []
    for epoch_idx in range(num_epochs):
        # We'll accumulate batch losses and show an average once per epoch.
        dsc_losses, gen_losses = [], []
        print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')

        with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
            for batch_idx, (x_data, _) in enumerate(dl_train):
                x_data = x_data.to(device)
                dsc_loss, gen_loss = train_batch(
                    dsc, gen,
                    dsc_loss_fn, gen_loss_fn,
                    dsc_optimizer, gen_optimizer,
                    x_data)
                dsc_losses.append(dsc_loss)
                gen_losses.append(gen_loss)
                pbar.update()

        dsc_avg_losses.append(np.mean(dsc_losses))
        gen_avg_losses.append(np.mean(gen_losses))
        print(f'Discriminator loss: {dsc_avg_losses[-1]}')
        print(f'Generator loss:     {gen_avg_losses[-1]}')
        
        if save_checkpoint(gen, dsc_avg_losses, gen_avg_losses, checkpoint_file):
            print(f'Saved checkpoint.')
            

        samples = gen.sample(5, with_grad=False)
        fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
        IPython.display.display(fig)
        plt.close(fig)
except KeyboardInterrupt as e:
    print('\n *** Training interrupted by user')
--- EPOCH 1/100 ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:39<00:00,  1.48s/it]
Discriminator loss: 0.2226154575986204
Generator loss:     10.399013006865088
--- EPOCH 2/100 ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:51<00:00,  1.66s/it]
Discriminator loss: 0.32406483643424155
Generator loss:     10.796153659251198
--- EPOCH 3/100 ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:40<00:00,  1.50s/it]
Discriminator loss: 0.6249632680538431
Generator loss:     7.866348905349845
--- EPOCH 4/100 ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:53<00:00,  1.70s/it]
Discriminator loss: 0.7597085722346804
Generator loss:     4.140996360956733
--- EPOCH 5/100 ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:04<00:00,  1.86s/it]
Discriminator loss: 0.7992192764780415
Generator loss:     4.1413471245053985
--- EPOCH 6/100 ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:41<00:00,  2.40s/it]
Discriminator loss: 0.760242141227224
Generator loss:     3.8490008097976003
--- EPOCH 7/100 ---
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:25<00:00,  1.27s/it]
Discriminator loss: 0.6488763174014305
Generator loss:     4.048653104412022
--- EPOCH 8/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.692744659287716
Generator loss:     4.1238828772929175
--- EPOCH 9/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.06s/it]
Discriminator loss: 0.7140244154342964
Generator loss:     3.651231741727288
--- EPOCH 10/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.10s/it]
Discriminator loss: 0.6570371590109904
Generator loss:     4.1205917472269995
--- EPOCH 11/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:17<00:00,  1.16s/it]
Discriminator loss: 0.7654145942695105
Generator loss:     3.452788523773649
--- EPOCH 12/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.10s/it]
Discriminator loss: 0.6773354987155146
Generator loss:     3.9122310796780373
--- EPOCH 13/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.686968584558857
Generator loss:     3.6455368479685997
--- EPOCH 14/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.6736044735272428
Generator loss:     3.794376923077142
--- EPOCH 15/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:21<00:00,  1.21s/it]
Discriminator loss: 0.6158740751778902
Generator loss:     3.928209099306989
--- EPOCH 16/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.7055960969248815
Generator loss:     3.748022351691972
--- EPOCH 17/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.7214510603182351
Generator loss:     4.0355255336903815
--- EPOCH 18/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.5925487618504176
Generator loss:     3.592918805222013
--- EPOCH 19/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.6489447415319841
Generator loss:     3.7618876855764816
--- EPOCH 20/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.6137689489926865
Generator loss:     3.728848683300303
--- EPOCH 21/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:20<00:00,  1.20s/it]
Discriminator loss: 0.533989105373621
Generator loss:     4.0369778053084415
--- EPOCH 22/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.09s/it]
Discriminator loss: 0.6231796376184741
Generator loss:     4.205059199190852
--- EPOCH 23/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.5462197742577809
Generator loss:     3.6763485750155662
--- EPOCH 24/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00,  1.12s/it]
Discriminator loss: 0.5820786693870131
Generator loss:     3.934920588536049
--- EPOCH 25/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.48068115604457573
Generator loss:     3.8703388989861334
--- EPOCH 26/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.5212051019406141
Generator loss:     4.146393323122566
--- EPOCH 27/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.47505673365806467
Generator loss:     4.043247221121147
--- EPOCH 28/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.48975709870235246
Generator loss:     3.9668766687165444
Saved checkpoint.
--- EPOCH 29/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.12s/it]
Discriminator loss: 0.41483799612789013
Generator loss:     4.285565899379217
--- EPOCH 30/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.47531898019473945
Generator loss:     4.09228524165367
--- EPOCH 31/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.06s/it]
Discriminator loss: 0.4484067074676503
Generator loss:     4.203467078173339
--- EPOCH 32/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.4615772899184654
Generator loss:     4.129581930032417
--- EPOCH 33/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.34161930709187666
Generator loss:     4.388808421234586
--- EPOCH 34/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.4297226559259553
Generator loss:     4.408588472586959
--- EPOCH 35/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.38908996421899367
Generator loss:     4.403439705051593
--- EPOCH 36/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.5031480472812901
Generator loss:     4.302793388046435
--- EPOCH 37/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.3036700967532485
Generator loss:     4.256853963012126
--- EPOCH 38/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.44352790401942693
Generator loss:     4.935855774737116
--- EPOCH 39/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00,  1.13s/it]
Discriminator loss: 0.34675667462731474
Generator loss:     4.575066045149049
--- EPOCH 40/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.06s/it]
Discriminator loss: 0.40445677874915636
Generator loss:     4.601134561780674
--- EPOCH 41/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.35126391709295673
Generator loss:     4.461568499678996
--- EPOCH 42/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.31365142673698826
Generator loss:     4.84396594140067
--- EPOCH 43/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.33251275033203526
Generator loss:     4.469906568527222
--- EPOCH 44/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.279036117523019
Generator loss:     4.753364780055943
--- EPOCH 45/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.06s/it]
Discriminator loss: 0.42701762508767754
Generator loss:     5.096534817966063
--- EPOCH 46/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.40647909551191685
Generator loss:     4.550182789119322
--- EPOCH 47/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.24964745699970134
Generator loss:     4.565602247394732
--- EPOCH 48/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.09s/it]
Discriminator loss: 0.26656937849388196
Generator loss:     4.488548634657219
--- EPOCH 49/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.09s/it]
Discriminator loss: 0.2868878970951287
Generator loss:     5.13614484089524
--- EPOCH 50/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.06s/it]
Discriminator loss: 0.2893540797504916
Generator loss:     4.9686783374245485
--- EPOCH 51/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.4205938978537695
Generator loss:     5.042785594712442
--- EPOCH 52/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.39453655745444904
Generator loss:     5.26520986521422
--- EPOCH 53/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.09s/it]
Discriminator loss: 0.3967029097587315
Generator loss:     4.636836155137019
--- EPOCH 54/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.2517954114435324
Generator loss:     4.709473912395648
--- EPOCH 55/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.21090908240137704
Generator loss:     4.746104042921493
--- EPOCH 56/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.10s/it]
Discriminator loss: 0.2222719030704961
Generator loss:     5.1047546383160265
--- EPOCH 57/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.09s/it]
Discriminator loss: 0.2371100701892109
Generator loss:     5.161920148934891
--- EPOCH 58/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.09s/it]
Discriminator loss: 0.23250969593871884
Generator loss:     5.226165280413272
--- EPOCH 59/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.10s/it]
Discriminator loss: 0.1713197364735959
Generator loss:     5.229283902182508
--- EPOCH 60/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.09s/it]
Discriminator loss: 0.3744734670244046
Generator loss:     5.22563999802319
--- EPOCH 61/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00,  1.13s/it]
Discriminator loss: 0.21818367603109845
Generator loss:     5.034501529451626
--- EPOCH 62/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.10s/it]
Discriminator loss: 0.2961589781595255
Generator loss:     5.973364061384059
--- EPOCH 63/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.09s/it]
Discriminator loss: 0.24063466767321773
Generator loss:     4.941025922547525
--- EPOCH 64/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:16<00:00,  1.14s/it]
Discriminator loss: 0.23336398223442817
Generator loss:     4.79269886728543
Saved checkpoint.
--- EPOCH 65/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.2037594327564115
Generator loss:     5.494116355234118
--- EPOCH 66/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.28042146480127944
Generator loss:     5.429578503566002
--- EPOCH 67/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.2380366637857992
Generator loss:     5.586613768961892
--- EPOCH 68/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.24321508038538828
Generator loss:     5.506548023935574
--- EPOCH 69/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.19822439951683157
Generator loss:     5.6423968984119925
--- EPOCH 70/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.19567888513652246
Generator loss:     5.2913757936278385
--- EPOCH 71/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.216128641795089
Generator loss:     5.444126164735253
--- EPOCH 72/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.1359927723768042
Generator loss:     5.344446449137446
--- EPOCH 73/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.08s/it]
Discriminator loss: 0.20982019202922708
Generator loss:     5.672997513813759
--- EPOCH 74/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.16212283858834808
Generator loss:     5.464307156961356
--- EPOCH 75/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.36618489243868574
Generator loss:     6.051519253360691
--- EPOCH 76/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.17745266220907666
Generator loss:     4.925476590199257
--- EPOCH 77/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.06s/it]
Discriminator loss: 0.18689542805859402
Generator loss:     5.445575205247794
--- EPOCH 78/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.05s/it]
Discriminator loss: 0.1568165964155055
Generator loss:     5.459341437069337
--- EPOCH 79/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.16645404059829108
Generator loss:     5.654817463746712
--- EPOCH 80/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:09<00:00,  1.04s/it]
Discriminator loss: 0.24319110466028326
Generator loss:     5.824600757057987
--- EPOCH 81/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:19<00:00,  1.18s/it]
Discriminator loss: 0.1486799336080231
Generator loss:     5.589747959108495
--- EPOCH 82/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:28<00:00,  1.32s/it]
Discriminator loss: 0.1359346567917226
Generator loss:     5.59498687644503
--- EPOCH 83/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:16<00:00,  1.14s/it]
Discriminator loss: 0.3582810720623429
Generator loss:     6.1349239803072235
--- EPOCH 84/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.09s/it]
Discriminator loss: 0.21791031133772723
Generator loss:     5.550275245709206
--- EPOCH 85/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00,  1.09s/it]
Discriminator loss: 0.1460703068982754
Generator loss:     5.643578984844151
--- EPOCH 86/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.1680278676881719
Generator loss:     5.423344418184081
--- EPOCH 87/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.05s/it]
Discriminator loss: 0.13634364540452387
Generator loss:     6.287931374649503
--- EPOCH 88/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00,  1.05s/it]
Discriminator loss: 0.15707427296620696
Generator loss:     6.208204162654592
--- EPOCH 89/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00,  1.13s/it]
Discriminator loss: 0.23691063135195134
Generator loss:     6.708966758713793
--- EPOCH 90/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:18<00:00,  1.17s/it]
Discriminator loss: 0.1935963678215422
Generator loss:     5.781576846962545
--- EPOCH 91/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00,  1.13s/it]
Discriminator loss: 0.1911312595232209
Generator loss:     5.672198094538788
Saved checkpoint.
--- EPOCH 92/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:20<00:00,  1.20s/it]
Discriminator loss: 0.11797445981916208
Generator loss:     5.658509987503735
Saved checkpoint.
--- EPOCH 93/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00,  1.10s/it]
Discriminator loss: 0.06029647434436118
Generator loss:     5.692577867365595
--- EPOCH 94/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.21785901588346088
Generator loss:     5.82088153397859
--- EPOCH 95/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.07s/it]
Discriminator loss: 0.11724033395745861
Generator loss:     5.538199930048701
--- EPOCH 96/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00,  1.06s/it]
Discriminator loss: 0.12432931063335333
Generator loss:     6.089405600704364
--- EPOCH 97/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00,  1.11s/it]
Discriminator loss: 0.1921722201991882
Generator loss:     6.962398450766036
--- EPOCH 98/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:58<00:00,  1.76s/it]
Discriminator loss: 0.12687940124088704
Generator loss:     6.596022196670077
--- EPOCH 99/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:23<00:00,  2.14s/it]
Discriminator loss: 0.2868398177757192
Generator loss:     6.26818527392487
--- EPOCH 100/100 ---
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:32<00:00,  2.28s/it]
Discriminator loss: 0.1492465426450345
Generator loss:     6.200201539850947
In [13]:
# Plot images from best or last model
if os.path.isfile(f'{checkpoint_file}.pt'):
    gen = torch.load(f'{checkpoint_file}.pt', map_location=device)
print('*** Images Generated from best model:')
samples = gen.sample(n=15, with_grad=False).cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))
*** Images Generated from best model:

Questions¶

TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.

In [8]:
from cs236781.answers import display_answer
import hw3.answers as answers

Question 1¶

Explain in detail why during training we sometimes need to maintain gradients when sampling from the GAN, and other times we don't. When are they maintained and why? When are they discarded and why?

In [9]:
display_answer(answers.part3_q1)

Your answer: In the training phase, we only train the discriminator alone, and in this phase we sample the images accordingly. We don't want these samples to affect the gradient of the generator, so we need to separate these samples from the backpropagation process. This can happen even if we don't intend to. Therefore, when we train the generator and freeze the discriminator, we preserve these gradients to improve the sampling power of the generator.

Question 2¶

  1. When training a GAN to generate images, should we decide to stop training solely based on the fact that the Generator loss is below some threshold? Why or why not?

  2. What does it mean if the discriminator loss remains at a constant value while the generator loss decreases?

In [10]:
display_answer(answers.part3_q2)

Your answer: 1) We shouldn't decide to stop training just because the generator loss is below a certain threshold, because if we look at the results, we can see that a low loss rate doesn't mean that the generator produces sound images=. Loss is defined by the ability of the discriminator to detect fake images, it does not measure sample quality. Sometimes the discriminator is not very good and the generator produces bad samples, but these samples can fool the discriminator.

2) If the discriminator loss remains constant and the generator loss decreases, it means that the discriminator cannot correctly identify real and fake samples. Generator improved and created better samples.

Question 3¶

Compare the results you got when generating images with the VAE to the GAN results. What's the main difference and what's causing it?

In [4]:
display_answer(answers.part3_q3)

Your answer: It can be said that the images we generate with VAE are smoother and more focused on human faces. If we compare it to the VAE, those generated by the GAN are more noisy and have multiple colors. This might be due to the differences in architecture and loss function between both networks. For example, if we compare the loss functions of these two: the VAE loss function is directly related to the dataset, unlike the GAN loss function, it is from a game theory perspective and has no direct relationship with the dataset, so the general picture related refers to the entire image, including the background and its colors. In the VAE dataset, we have a common face, and because of its architecture and care for mutual information in the input and decoded images, it preserves the common features in the resulting decoded images without preserving the background and its color .

In [ ]: